{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 翻訳 (TensorFlow)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will need to set up git; adapt your email and name in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "raw_datasets = load_dataset(\"kde4\", lang1=\"en\", lang2=\"fr\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'translation'],\n",
       "        num_rows: 210173\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'translation'],\n",
       "        num_rows: 189155\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'translation'],\n",
       "        num_rows: 21018\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split_datasets = raw_datasets[\"train\"].train_test_split(train_size=0.9, seed=20)\n",
    "split_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "split_datasets[\"validation\"] = split_datasets.pop(\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'en': 'Default to expanded threads',\n",
       " 'fr': 'Par défaut, développer les fils de discussion'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split_datasets[\"train\"][1][\"translation\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'translation_text': 'Par défaut pour les threads élargis'}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n",
    "translator = pipeline(\"translation\", model=model_checkpoint)\n",
    "translator(\"Default to expanded threads\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'en': 'Unable to import %1 using the OFX importer plugin. This file is not the correct format.',\n",
       " 'fr': \"Impossible d'importer %1 en utilisant le module d'extension d'importation OFX. Ce fichier n'a pas un format correct.\"}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "split_datasets[\"train\"][172][\"translation\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'translation_text': \"Impossible d'importer %1 en utilisant le plugin d'importateur OFX. Ce fichier n'est pas le bon format.\"}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "translator(\n",
    "    \"Unable to import %1 using the OFX importer plugin. This file is not the correct format.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n",
    "# `return_tensors` is a per-call tokenizer argument, not a `from_pretrained` option,\n",
    "# so it is not passed here; batches are converted to TF tensors later by the data collator.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "en_sentence = split_datasets[\"train\"][1][\"translation\"][\"en\"]\n",
    "fr_sentence = split_datasets[\"train\"][1][\"translation\"][\"fr\"]\n",
    "\n",
    "inputs = tokenizer(en_sentence)\n",
    "# `text_target` tokenizes the sentence in target (French) mode; it replaces the\n",
    "# deprecated `as_target_tokenizer()` context manager.\n",
    "targets = tokenizer(text_target=fr_sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁Par', '▁dé', 'f', 'aut', ',', '▁dé', 've', 'lop', 'per', '▁les', '▁fil', 's', '▁de', '▁discussion', '</s>']\n",
       "['▁Par', '▁défaut', ',', '▁développer', '▁les', '▁fils', '▁de', '▁discussion', '</s>']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wrong_targets = tokenizer(fr_sentence)\n",
    "print(tokenizer.convert_ids_to_tokens(wrong_targets[\"input_ids\"]))\n",
    "print(tokenizer.convert_ids_to_tokens(targets[\"input_ids\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_input_length = 128\n",
    "max_target_length = 128\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch of English/French pairs into model inputs and labels.\n",
    "\n",
    "    Args:\n",
    "        examples: Batch whose \"translation\" column is a list of dicts with \"en\" and \"fr\" keys.\n",
    "\n",
    "    Returns:\n",
    "        The tokenized English inputs, with the French token ids stored under \"labels\".\n",
    "    \"\"\"\n",
    "    inputs = [ex[\"en\"] for ex in examples[\"translation\"]]\n",
    "    targets = [ex[\"fr\"] for ex in examples[\"translation\"]]\n",
    "    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)\n",
    "\n",
    "    # Tokenize the targets in target mode via `text_target` (replaces the deprecated\n",
    "    # `as_target_tokenizer()` context manager). A separate call keeps the target\n",
    "    # truncation length independent of the input one.\n",
    "    labels = tokenizer(\n",
    "        text_target=targets, max_length=max_target_length, truncation=True\n",
    "    )\n",
    "\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_datasets = split_datasets.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    remove_columns=split_datasets[\"train\"].column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TFAutoModelForSeq2SeqLM\n",
    "\n",
    "model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, from_pt=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForSeq2Seq\n",
    "\n",
    "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors=\"tf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['attention_mask', 'input_ids', 'labels', 'decoder_input_ids'])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch = data_collator([tokenized_datasets[\"train\"][i] for i in range(1, 3)])\n",
    "batch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  577,  5891,     2,  3184,    16,  2542,     5,  1710,     0,  -100,\n",
       "          -100,  -100,  -100,  -100,  -100,  -100],\n",
       "        [ 1211,     3,    49,  9409,  1211,     3, 29140,   817,  3124,   817,\n",
       "           550,  7032,  5821,  7907, 12649,     0]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch[\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[59513,   577,  5891,     2,  3184,    16,  2542,     5,  1710,     0,\n",
       "         59513, 59513, 59513, 59513, 59513, 59513],\n",
       "        [59513,  1211,     3,    49,  9409,  1211,     3, 29140,   817,  3124,\n",
       "           817,   550,  7032,  5821,  7907, 12649]])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch[\"decoder_input_ids\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[577, 5891, 2, 3184, 16, 2542, 5, 1710, 0]\n",
       "[1211, 3, 49, 9409, 1211, 3, 29140, 817, 3124, 817, 550, 7032, 5821, 7907, 12649, 0]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for i in range(1, 3):\n",
    "    print(tokenized_datasets[\"train\"][i][\"labels\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training pipeline: batches of 32, shuffled, collated by `data_collator`.\n",
    "tf_train_dataset = tokenized_datasets[\"train\"].to_tf_dataset(\n",
    "    columns=[\"input_ids\", \"attention_mask\", \"labels\"],\n",
    "    collate_fn=data_collator,\n",
    "    shuffle=True,\n",
    "    batch_size=32,\n",
    ")\n",
    "# Validation pipeline: same columns and collator, batches of 16, original order kept.\n",
    "tf_eval_dataset = tokenized_datasets[\"validation\"].to_tf_dataset(\n",
    "    columns=[\"input_ids\", \"attention_mask\", \"labels\"],\n",
    "    collate_fn=data_collator,\n",
    "    shuffle=False,\n",
    "    batch_size=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install sacrebleu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"sacrebleu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'score': 46.750469682990165,\n",
       " 'counts': [11, 6, 4, 3],\n",
       " 'totals': [12, 11, 10, 9],\n",
       " 'precisions': [91.67, 54.54, 40.0, 33.33],\n",
       " 'bp': 0.9200444146293233,\n",
       " 'sys_len': 12,\n",
       " 'ref_len': 13}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = [\n",
    "    \"This plugin lets you translate web pages between several languages automatically.\"\n",
    "]\n",
    "references = [\n",
    "    [\n",
    "        \"This plugin allows you to automatically translate web pages between several languages.\"\n",
    "    ]\n",
    "]\n",
    "metric.compute(predictions=predictions, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'score': 1.683602693167689,\n",
       " 'counts': [1, 0, 0, 0],\n",
       " 'totals': [4, 3, 2, 1],\n",
       " 'precisions': [25.0, 16.67, 12.5, 12.5],\n",
       " 'bp': 0.10539922456186433,\n",
       " 'sys_len': 4,\n",
       " 'ref_len': 13}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = [\"This This This This\"]\n",
    "references = [\n",
    "    [\n",
    "        \"This plugin allows you to automatically translate web pages between several languages.\"\n",
    "    ]\n",
    "]\n",
    "metric.compute(predictions=predictions, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'score': 0.0,\n",
       " 'counts': [2, 1, 0, 0],\n",
       " 'totals': [2, 1, 0, 0],\n",
       " 'precisions': [100.0, 100.0, 0.0, 0.0],\n",
       " 'bp': 0.004086771438464067,\n",
       " 'sys_len': 2,\n",
       " 'ref_len': 13}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = [\"This plugin\"]\n",
    "references = [\n",
    "    [\n",
    "        \"This plugin allows you to automatically translate web pages between several languages.\"\n",
    "    ]\n",
    "]\n",
    "metric.compute(predictions=predictions, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(num_samples=200, seed=42):\n",
    "    \"\"\"Score `model` with the SacreBLEU metric on a sample of the validation split.\n",
    "\n",
    "    Args:\n",
    "        num_samples: Maximum number of validation examples to evaluate\n",
    "            (capped at the size of the validation split).\n",
    "        seed: Seed for the shuffle that picks the sample. Keeping it fixed means\n",
    "            every call scores the same examples, so the results obtained before\n",
    "            and after training can be compared.\n",
    "\n",
    "    Returns:\n",
    "        A dict with a single key, \"bleu\", holding the metric's \"score\" value.\n",
    "    \"\"\"\n",
    "    validation_set = tokenized_datasets[\"validation\"]\n",
    "    sample_size = min(num_samples, len(validation_set))\n",
    "    sampled_dataset = validation_set.shuffle(seed=seed).select(range(sample_size))\n",
    "    tf_generate_dataset = sampled_dataset.to_tf_dataset(\n",
    "        columns=[\"input_ids\", \"attention_mask\", \"labels\"],\n",
    "        collate_fn=data_collator,\n",
    "        shuffle=False,\n",
    "        batch_size=4,\n",
    "    )\n",
    "\n",
    "    all_preds = []\n",
    "    all_labels = []\n",
    "    for batch in tf_generate_dataset:\n",
    "        predictions = model.generate(\n",
    "            input_ids=batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"]\n",
    "        )\n",
    "        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n",
    "        # Padded label positions hold -100; swap in the pad token id so they can be decoded.\n",
    "        labels = batch[\"labels\"].numpy()\n",
    "        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "        all_preds.extend(pred.strip() for pred in decoded_preds)\n",
    "        # Each prediction is paired with a list of references (here a single one).\n",
    "        all_labels.extend([label.strip()] for label in decoded_labels)\n",
    "\n",
    "    result = metric.compute(predictions=all_preds, references=all_labels)\n",
    "    return {\"bleu\": result[\"score\"]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(compute_metrics())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import create_optimizer\n",
    "from transformers.keras_callbacks import PushToHubCallback\n",
    "import tensorflow as tf\n",
    "\n",
    "# The number of training steps is the number of samples in the dataset, divided by the batch size then multiplied\n",
    "# by the total number of epochs. Note that the tf_train_dataset here is a batched tf.data.Dataset,\n",
    "# not the original Hugging Face Dataset, so its len() is already num_samples // batch_size.\n",
    "num_epochs = 3\n",
    "num_train_steps = len(tf_train_dataset) * num_epochs\n",
    "\n",
    "# Optimizer configured with initial learning rate 5e-5, no warmup steps and weight decay 0.01;\n",
    "# `create_optimizer` also returns `schedule`, which is not used further in this cell.\n",
    "optimizer, schedule = create_optimizer(\n",
    "    init_lr=5e-5,\n",
    "    num_warmup_steps=0,\n",
    "    num_train_steps=num_train_steps,\n",
    "    weight_decay_rate=0.01,\n",
    ")\n",
    "# No loss is passed to compile(); presumably the model falls back to its built-in loss -- confirm.\n",
    "model.compile(optimizer=optimizer)\n",
    "\n",
    "# Train in mixed-precision float16\n",
    "# NOTE(review): this policy is set after `model` has already been created and compiled above.\n",
    "# Keras documents that the global policy should be set before layers are constructed --\n",
    "# confirm that it actually takes effect for `model`.\n",
    "tf.keras.mixed_precision.set_global_policy(\"mixed_float16\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.keras_callbacks import PushToHubCallback\n",
    "\n",
    "# The callback is given the local output directory and the tokenizer; presumably it uploads\n",
    "# both model and tokenizer to the Hub during training (requires the login above) -- confirm.\n",
    "callback = PushToHubCallback(\n",
    "    output_dir=\"marian-finetuned-kde4-en-to-fr\", tokenizer=tokenizer\n",
    ")\n",
    "\n",
    "# Fine-tune for `num_epochs` epochs, evaluating on the validation pipeline.\n",
    "model.fit(\n",
    "    tf_train_dataset,\n",
    "    validation_data=tf_eval_dataset,\n",
    "    callbacks=[callback],\n",
    "    epochs=num_epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(compute_metrics())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'translation_text': 'Par défaut, développer les fils de discussion'}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Replace this with your own checkpoint\n",
    "model_checkpoint = \"huggingface-course/marian-finetuned-kde4-en-to-fr\"\n",
    "translator = pipeline(\"translation\", model=model_checkpoint)\n",
    "translator(\"Default to expanded threads\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'translation_text': \"Impossible d'importer %1 en utilisant le module externe d'importation OFX. Ce fichier n'est pas le bon format.\"}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "translator(\n",
    "    \"Unable to import %1 using the OFX importer plugin. This file is not the correct format.\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "翻訳 (TensorFlow)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
